{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e1c27274",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['S数字:3929=15717571757175716E',\n",
       " 'S小写:3616=一四四六五四四六五四四六五四四六四E',\n",
       " 'S大写:6438=贰伍柒伍肆伍柒伍肆伍柒伍肆伍柒伍贰E',\n",
       " 'S字母:5966=weiyyeiyyeiyyeiyrE',\n",
       " 'S小写:7716=三〇八六七〇八六七〇八六七〇八六四E',\n",
       " 'S大写:7307=贰玖贰叁零玖贰叁零玖贰叁零玖贰贰捌E',\n",
       " 'S数字:9302=37211721172117208E',\n",
       " 'S字母:7822=eqwoqqwoqqwoqqwiiE',\n",
       " 'S小写:7413=二九六五四九六五四九六五四九六五二E',\n",
       " 'S小写:6266=二五〇六六五〇六六五〇六六五〇六四E']"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "\n",
    "class Tokenizer:\n",
    "\n",
    "    def __init__(self):\n",
    "        self.vocab = {\n",
    "            'mark': list('PSEU'),\n",
    "            'number': list('0123456789'),\n",
    "            'letter': list('pqwertyuio'),\n",
    "            'chinese_lower': list('〇一二三四五六七八九'),\n",
    "            'chinese_upper': list('零壹贰叁肆伍陆柒捌玖'),\n",
    "            'other': list('数字大写小母:=_'),\n",
    "        }\n",
    "\n",
    "        self.decoder = [j for i in self.vocab.values() for j in i]\n",
    "        self.encoder = {j: i for i, j in enumerate(self.decoder)}\n",
    "\n",
    "        self.label = {\n",
    "            'number': 0,\n",
    "            'letter': 1,\n",
    "            'chinese_lower': 2,\n",
    "            'chinese_upper': 3\n",
    "        }\n",
    "        self.prefix = ['数字', '字母', '小写', '大写']\n",
    "\n",
    "    def decode(self, x):\n",
    "        return ''.join([self.decoder[i] for i in x])\n",
    "\n",
    "    def get_data(self, prefix):\n",
    "        #生成问题和答案\n",
    "        question = random.randint(1000, 9999)\n",
    "        answer = int(str(question) * 4) * 4\n",
    "        #answer = question**8\n",
    "        \n",
    "        question = list(str(question))\n",
    "        answer = list(str(answer))\n",
    "\n",
    "        #随机label\n",
    "        label = random.choice(list(self.label.keys()))\n",
    "\n",
    "        #根据label替换答案成其他字符集\n",
    "        answer = [self.vocab[label][int(i)] for i in answer]\n",
    "\n",
    "        #label转数字\n",
    "        label = self.label[label]\n",
    "\n",
    "        #组合问题和答案\n",
    "        if prefix:\n",
    "            prefix = list(self.prefix[label])\n",
    "        else:\n",
    "            prefix = list('__')\n",
    "        token = prefix + [':'] + question + ['='] + answer\n",
    "\n",
    "        #编码\n",
    "        token = [self.encoder[i] for i in token]\n",
    "        token = [self.encoder['S']] + token + [self.encoder['E']]\n",
    "\n",
    "        return label, token\n",
    "\n",
    "    def get_batch_data(self, prefix):\n",
    "        data = [self.get_data(prefix=prefix) for _ in range(64)]\n",
    "\n",
    "        label = [i[0] for i in data]\n",
    "        token = [i[1] for i in data]\n",
    "\n",
    "        return label, *self.batch_pad(token=token)\n",
    "\n",
    "    def batch_pad(self, text=None, token=None):\n",
    "        if text:\n",
    "            #编码\n",
    "            token = [[self.encoder[j] for j in i] for i in text]\n",
    "\n",
    "        lens = max([len(i) for i in token])\n",
    "\n",
    "        input_ids = []\n",
    "        attention_mask = []\n",
    "        for i in token:\n",
    "            attention_mask.append([1] * len(i) + [0] * (lens - len(i)))\n",
    "            input_ids.append(i + [self.encoder['P']] * (lens - len(i)))\n",
    "\n",
    "        return input_ids, attention_mask\n",
    "\n",
    "\n",
    "tokenizer = Tokenizer()\n",
    "\n",
    "[tokenizer.decode(i) for i in tokenizer.get_batch_data(prefix=True)[1]][:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e242caa8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'cuda'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "17208b2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ModelGEN(torch.nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        from transformers import GPT2Config, GPT2Model\n",
    "\n",
    "        self.config = GPT2Config(bos_token_id=tokenizer.encoder['S'],\n",
    "                                 eos_token_id=tokenizer.encoder['E'],\n",
    "                                 n_embd=64,\n",
    "                                 n_head=4,\n",
    "                                 n_layer=4,\n",
    "                                 n_positions=128,\n",
    "                                 vocab_size=len(tokenizer.decoder))\n",
    "\n",
    "        self.feature = GPT2Model(self.config)\n",
    "\n",
    "        self.fc_out = torch.nn.Linear(64, self.config.vocab_size, bias=False)\n",
    "\n",
    "        self.to(device)\n",
    "        self.train()\n",
    "\n",
    "    def forward(self, input_ids, attention_mask):\n",
    "        out = self.feature(input_ids=input_ids,\n",
    "                           attention_mask=attention_mask).last_hidden_state\n",
    "\n",
    "        return self.fc_out(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a23b28dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ModelCLS(torch.nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        from transformers import BertConfig, BertModel\n",
    "\n",
    "        self.config = BertConfig(hidden_size=64,\n",
    "                                 intermediate_size=64,\n",
    "                                 max_position_embeddings=128,\n",
    "                                 num_attention_heads=4,\n",
    "                                 num_hidden_layers=4,\n",
    "                                 vocab_size=len(tokenizer.decoder))\n",
    "\n",
    "        self.feature = BertModel(self.config)\n",
    "\n",
    "        self.fc_out = torch.nn.Sequential(torch.nn.Dropout(p=0.1),\n",
    "                                          torch.nn.Linear(64, 4))\n",
    "\n",
    "        self.to(device)\n",
    "        self.train()\n",
    "\n",
    "    def forward(self, input_ids, attention_mask):\n",
    "        out = self.feature(input_ids=input_ids,\n",
    "                           attention_mask=attention_mask).pooler_output\n",
    "\n",
    "        return self.fc_out(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c6224a5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ModelPPO(torch.nn.Module):\n",
    "\n",
    "    def __init__(self, model_gen):\n",
    "        super().__init__()\n",
    "        self.model_gen = model_gen\n",
    "        self.v_head = torch.nn.Sequential(torch.nn.Dropout(0.1),\n",
    "                                          torch.nn.Linear(64, 1))\n",
    "\n",
    "        self.to(device)\n",
    "        self.train()\n",
    "\n",
    "    def forward(self, input_ids, attention_mask):\n",
    "        last_hidden_state = self.model_gen.feature(\n",
    "            input_ids=input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            output_hidden_states=True).last_hidden_state\n",
    "\n",
    "        logits = self.model_gen.fc_out(last_hidden_state)\n",
    "        value = self.v_head(last_hidden_state).squeeze(-1)\n",
    "\n",
    "        return logits, value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a13f60dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "generater = None\n",
    "\n",
    "\n",
    "def generate(model_gen, input_ids):\n",
    "    global generater\n",
    "    if not generater:\n",
    "        #包装类,用于生成\n",
    "        from transformers import GPT2LMHeadModel\n",
    "        generater = GPT2LMHeadModel(model_gen.config)\n",
    "        generater.transformer = model_gen.feature\n",
    "        generater.lm_head = model_gen.fc_out\n",
    "        generater.to(device)\n",
    "\n",
    "    return generater.generate(input_ids=input_ids,\n",
    "                              min_length=-1,\n",
    "                              top_k=0.0,\n",
    "                              top_p=1.0,\n",
    "                              do_sample=True,\n",
    "                              pad_token_id=tokenizer.encoder['P'],\n",
    "                              max_new_tokens=25,\n",
    "                              eos_token_id=tokenizer.encoder['E'])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:cuda117]",
   "language": "python",
   "name": "conda-env-cuda117-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
